{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b02eec08",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad7e61b8dafe472398571ec163dbd224",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import LlamaForCausalLM, LlamaTokenizer\n",
    "\n",
    "# Single source of truth for the checkpoint directory, so the tokenizer and\n",
    "# the model can never be loaded from two different locations by accident.\n",
    "MODEL_PATH = \"../KoAlpaca/\"\n",
    "# Prefer the first GPU; fall back to CPU so this cell does not crash on a\n",
    "# machine without CUDA.\n",
    "DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)\n",
    "model = LlamaForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "660839b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[    1, 29871, 31734,   238,   136,   152, 30944, 31578, 31527, 29973,\n",
       "         29871,   239,   163,   131, 31081, 29871,   239,   134,   139, 30906]],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: tokenize a short Korean greeting, move the tensors to the GPU,\n",
    "# and let the model continue it with no extra generation arguments.\n",
    "greeting_inputs = tokenizer('안녕하세요?', return_tensors='pt').to('cuda:0')\n",
    "model.generate(**greeting_inputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "8b80ea26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bilingual (English + Korean) instruction prompt templates.\n",
    "#   prompt_input    -> placeholders: {instruction}, {input}\n",
    "#   prompt_no_input -> placeholder:  {instruction}\n",
    "# Both templates end right after the response header so the model writes the answer.\n",
    "PROMPT_DICT = dict(\n",
    "    prompt_input=\"\".join([\n",
    "        \"Below is an instruction that describes a task, paired with an input that provides further context.\\n\",\n",
    "        \"아래는 작업을 설명하는 명령어와 추가적 맥락을 제공하는 입력이 짝을 이루는 예제입니다.\\n\\n\",\n",
    "        \"Write a response that appropriately completes the request.\\n\",\n",
    "        \"요청을 적절히 완료하는 응답을 작성하세요.\\n\\n\",\n",
    "        \"### Instruction(명령어):\\n{instruction}\\n\\n\",\n",
    "        \"### Input(입력):\\n{input}\\n\\n\",\n",
    "        \"### Response(응답):\",\n",
    "    ]),\n",
    "    prompt_no_input=\"\".join([\n",
    "        \"Below is an instruction that describes a task.\\n\",\n",
    "        \"아래는 작업을 설명하는 명령어입니다.\\n\\n\",\n",
    "        \"Write a response that appropriately completes the request.\\n\",\n",
    "        \"명령어에 따른 요청을 적절히 완료하는 응답을 작성하세요.\\n\\n\",\n",
    "        \"### Instruction(명령어):\\n{instruction}\\n\\n\",\n",
    "        \"### Response(응답):\",\n",
    "    ]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "b3cdda6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen(prompt, user_input=None, max_new_tokens=128, temperature=0.5):\n",
    "    \"\"\"Generate a sampled response to an instruction.\n",
    "\n",
    "    Relies on the notebook-level ``tokenizer``, ``model`` and ``PROMPT_DICT``.\n",
    "\n",
    "    Args:\n",
    "        prompt: Instruction text placed in the template's ``{instruction}`` slot.\n",
    "        user_input: Optional context for the ``{input}`` slot. When falsy\n",
    "            (``None`` or an empty string) the no-input template is used.\n",
    "        max_new_tokens: Maximum number of tokens to generate.\n",
    "        temperature: Sampling temperature forwarded to ``model.generate``.\n",
    "\n",
    "    Returns:\n",
    "        str: The decoded continuation only, with the prompt and special\n",
    "        tokens removed.\n",
    "    \"\"\"\n",
    "    if user_input:\n",
    "        x = PROMPT_DICT['prompt_input'].format(instruction=prompt, input=user_input)\n",
    "    else:\n",
    "        x = PROMPT_DICT['prompt_no_input'].format(instruction=prompt)\n",
    "\n",
    "    # Follow whatever device the model lives on instead of hard-coding 'cuda:0'.\n",
    "    input_ids = tokenizer.encode(x, return_tensors=\"pt\").to(model.device)\n",
    "    gen_tokens = model.generate(\n",
    "        input_ids,\n",
    "        max_new_tokens=max_new_tokens,\n",
    "        num_return_sequences=1,\n",
    "        temperature=temperature,\n",
    "        no_repeat_ngram_size=6,\n",
    "        do_sample=True,\n",
    "    )\n",
    "    # The returned sequence starts with the prompt tokens. Drop them by token\n",
    "    # count rather than with str.replace on the decoded text: decoding the\n",
    "    # prompt is not guaranteed to reproduce `x` character-for-character, and\n",
    "    # when it does not, the whole prompt would leak into the result.\n",
    "    prompt_length = input_ids.shape[-1]\n",
    "    new_tokens = gen_tokens[0][prompt_length:]\n",
    "\n",
    "    return tokenizer.decode(new_tokens, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a2ff94e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " ```python\n",
      "import datetime\n",
      "\n",
      "now = datetime.datetime.now()\n",
      "print(now.strftime(\"%Y-%m-%d %H:%M:%S\"))\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "# Demo: an instruction-only request (no user_input is passed to gen).\n",
    "prompt = \"Python으로 uptime을 찾는 코드\"\n",
    "generated_text = gen(prompt=prompt)\n",
    "print(generated_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b4f7e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
